{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.02095911, -0.04298958,  0.04404614, -0.0466527 ], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, done, _, info = self.env.step(action)\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wjRvJ8zBftL3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.CustomCallback at 0x7f53a8ae42b0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.callbacks import BaseCallback\n",
    "\n",
    "\n",
    "#Callback语法\n",
    "class CustomCallback(BaseCallback):\n",
    "\n",
    "    def __init__(self, verbose=0):\n",
    "        super().__init__(verbose)\n",
    "\n",
    "        #可以访问的变量\n",
    "        #self.model\n",
    "        #self.training_env\n",
    "        #self.n_calls\n",
    "        #self.num_timesteps\n",
    "        #self.locals\n",
    "        #self.globals\n",
    "        #self.logger\n",
    "        #self.parent\n",
    "\n",
    "    def _on_training_start(self) -> None:\n",
    "        #第一个rollout开始前调用\n",
    "        pass\n",
    "\n",
    "    def _on_rollout_start(self) -> None:\n",
    "        #rollout开始前\n",
    "        pass\n",
    "\n",
    "    def _on_step(self) -> bool:\n",
    "        #env.step()之后调用,返回False后停止训练\n",
    "        return True\n",
    "\n",
    "    def _on_rollout_end(self) -> None:\n",
    "        #更新参数前调用\n",
    "        pass\n",
    "\n",
    "    def _on_training_end(self) -> None:\n",
    "        #训练结束前调用\n",
    "        pass\n",
    "\n",
    "\n",
    "CustomCallback()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7ILY0AkFfzPJ"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20\n",
      "40\n",
      "60\n",
      "80\n",
      "100\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<stable_baselines3.ppo.ppo.PPO at 0x7f53a8ae4430>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3 import PPO\n",
    "\n",
    "\n",
    "#让训练只执行N步的callback\n",
    "class SimpleCallback(BaseCallback):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(verbose=0)\n",
    "        self.call_count = 0\n",
    "\n",
    "    def _on_step(self):\n",
    "        self.call_count += 1\n",
    "\n",
    "        if self.call_count % 20 == 0:\n",
    "            print(self.call_count)\n",
    "\n",
    "        if self.call_count >= 100:\n",
    "            return False\n",
    "\n",
    "        return True\n",
    "\n",
    "\n",
    "model = PPO('MlpPolicy', MyWrapper(), verbose=0)\n",
    "\n",
    "model.learn(8000, callback=SimpleCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(1267.75, 845.1674316370692)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.env_util import make_vec_env\n",
    "from stable_baselines3 import A2C\n",
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "import gym\n",
    "\n",
    "\n",
    "def test_callback(callback):\n",
    "\n",
    "    #创建Monitor封装的环境,这会在训练过程中写出日志文件到models文件夹\n",
    "    env = make_vec_env(MyWrapper, n_envs=1, monitor_dir='models')\n",
    "\n",
    "    #等价写法\n",
    "    # from stable_baselines3.common.monitor import Monitor\n",
    "    # from stable_baselines3.common.vec_env import DummyVecEnv\n",
    "    # env = Monitor(MyWrapper(), 'models')\n",
    "    # env = DummyVecEnv([lambda: env])\n",
    "\n",
    "    #训练\n",
    "    model = A2C('MlpPolicy', env, verbose=0).learn(total_timesteps=5000,\n",
    "                                                   callback=callback)\n",
    "\n",
    "    #测试\n",
    "    return evaluate_policy(model, MyWrapper(), n_eval_episodes=20)\n",
    "\n",
    "\n",
    "#使用Monitor封装的环境训练一个模型,保存下日志\n",
    "#只是为了测试load_results, ts2xy这两个函数\n",
    "test_callback(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>index</th>\n",
       "      <th>r</th>\n",
       "      <th>l</th>\n",
       "      <th>t</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>32</td>\n",
       "      <td>0.049117</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>20.0</td>\n",
       "      <td>20</td>\n",
       "      <td>0.071115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>23.0</td>\n",
       "      <td>23</td>\n",
       "      <td>0.097009</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>27.0</td>\n",
       "      <td>27</td>\n",
       "      <td>0.128370</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>41.0</td>\n",
       "      <td>41</td>\n",
       "      <td>0.173688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>79</td>\n",
       "      <td>88.0</td>\n",
       "      <td>88</td>\n",
       "      <td>5.220903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80</th>\n",
       "      <td>80</td>\n",
       "      <td>80.0</td>\n",
       "      <td>80</td>\n",
       "      <td>5.312054</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>81</th>\n",
       "      <td>81</td>\n",
       "      <td>82.0</td>\n",
       "      <td>82</td>\n",
       "      <td>5.447958</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82</th>\n",
       "      <td>82</td>\n",
       "      <td>121.0</td>\n",
       "      <td>121</td>\n",
       "      <td>5.603233</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>83</th>\n",
       "      <td>83</td>\n",
       "      <td>238.0</td>\n",
       "      <td>238</td>\n",
       "      <td>5.901814</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>84 rows × 4 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    index      r    l         t\n",
       "0       0   32.0   32  0.049117\n",
       "1       1   20.0   20  0.071115\n",
       "2       2   23.0   23  0.097009\n",
       "3       3   27.0   27  0.128370\n",
       "4       4   41.0   41  0.173688\n",
       "..    ...    ...  ...       ...\n",
       "79     79   88.0   88  5.220903\n",
       "80     80   80.0   80  5.312054\n",
       "81     81   82.0   82  5.447958\n",
       "82     82  121.0  121  5.603233\n",
       "83     83  238.0  238  5.901814\n",
       "\n",
       "[84 rows x 4 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from stable_baselines3.common.results_plotter import load_results, ts2xy\n",
    "\n",
    "#加载日志,这里找的是models/*.monitor.csv\n",
    "load_results('models')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([  32,   52,   75,  102,  143,  153,  169,  183,  200,  212,  227,\n",
       "         252,  283,  317,  358,  377,  399,  421,  445,  505,  537,  627,\n",
       "         674,  725,  753,  808,  852,  900,  989, 1053, 1117, 1203, 1251,\n",
       "        1304, 1320, 1424, 1458, 1483, 1604, 1679, 1711, 1769, 1851, 1890,\n",
       "        1960, 2048, 2159, 2204, 2324, 2387, 2424, 2476, 2541, 2564, 2587,\n",
       "        2643, 2678, 2703, 2769, 2823, 2952, 3010, 3040, 3129, 3294, 3392,\n",
       "        3420, 3562, 3699, 3740, 3987, 4057, 4087, 4151, 4261, 4302, 4331,\n",
       "        4363, 4385, 4473, 4553, 4635, 4756, 4994]),\n",
       " array([ 32.,  20.,  23.,  27.,  41.,  10.,  16.,  14.,  17.,  12.,  15.,\n",
       "         25.,  31.,  34.,  41.,  19.,  22.,  22.,  24.,  60.,  32.,  90.,\n",
       "         47.,  51.,  28.,  55.,  44.,  48.,  89.,  64.,  64.,  86.,  48.,\n",
       "         53.,  16., 104.,  34.,  25., 121.,  75.,  32.,  58.,  82.,  39.,\n",
       "         70.,  88., 111.,  45., 120.,  63.,  37.,  52.,  65.,  23.,  23.,\n",
       "         56.,  35.,  25.,  66.,  54., 129.,  58.,  30.,  89., 165.,  98.,\n",
       "         28., 142., 137.,  41., 247.,  70.,  30.,  64., 110.,  41.,  29.,\n",
       "         32.,  22.,  88.,  80.,  82., 121., 238.]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ts2xy(load_results('models'), 'timesteps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nzMHj7r3h78m"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000 -inf 37.23076923076923\n",
      "save 968\n",
      "2000 37.23076923076923 45.86046511627907\n",
      "save 1972\n",
      "3000 45.86046511627907 54.870370370370374\n",
      "save 2963\n",
      "4000 54.870370370370374 58.705882352941174\n",
      "save 3992\n",
      "5000 58.705882352941174 65.13333333333334\n",
      "save 4885\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(175.35, 34.45616780781055)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#保存最优模型\n",
    "class SaveOnBestTrainingRewardCallback(BaseCallback):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__(verbose=0)\n",
    "\n",
    "        self.best = -float('inf')\n",
    "\n",
    "    def _on_step(self):\n",
    "        #self.n_calls是个从1开始的数\n",
    "        if self.n_calls % 1000 != 0:\n",
    "            return True\n",
    "\n",
    "        #读取日志\n",
    "        x, y = ts2xy(load_results('models'), 'timesteps')\n",
    "\n",
    "        #求最后100个reward的均值\n",
    "        mean_reward = sum(y[-100:]) / len(y[-100:])\n",
    "\n",
    "        print(self.num_timesteps, self.best, mean_reward)\n",
    "\n",
    "        #判断保存\n",
    "        if mean_reward > self.best:\n",
    "            self.best = mean_reward\n",
    "            print('save', x[-1])\n",
    "            self.model.save('models/best_model')\n",
    "\n",
    "        return True\n",
    "\n",
    "\n",
    "test_callback(SaveOnBestTrainingRewardCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "c0Bu1HWKRUn4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n",
      "x= [ 50  80 110 136 167 204 240 270 295 322 391 515 601 620 662 699 738 764\n",
      " 805 838 873 904 929 954 979]\n",
      "y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.\n",
      "  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.]\n",
      "2000\n",
      "x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620\n",
      "  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058\n",
      " 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591\n",
      " 1632 1692 1752 1781 1814 1850 1898 1918 1959]\n",
      "y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.\n",
      "  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.  42.  21.  16.\n",
      "  22.  28.  49.  53.  24.  30.  33.  40.  60.  54.  17.  55.  38.  30.\n",
      "  41.  60.  60.  29.  33.  36.  48.  20.  41.]\n",
      "3000\n",
      "x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620\n",
      "  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058\n",
      " 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591\n",
      " 1632 1692 1752 1781 1814 1850 1898 1918 1959 2014 2064 2094 2138 2194\n",
      " 2227 2264 2308 2351 2387 2431 2475 2508 2558 2765 2802 2842 2889 2999]\n",
      "y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.\n",
      "  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.  42.  21.  16.\n",
      "  22.  28.  49.  53.  24.  30.  33.  40.  60.  54.  17.  55.  38.  30.\n",
      "  41.  60.  60.  29.  33.  36.  48.  20.  41.  55.  50.  30.  44.  56.\n",
      "  33.  37.  44.  43.  36.  44.  44.  33.  50. 207.  37.  40.  47. 110.]\n",
      "4000\n",
      "x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620\n",
      "  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058\n",
      " 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591\n",
      " 1632 1692 1752 1781 1814 1850 1898 1918 1959 2014 2064 2094 2138 2194\n",
      " 2227 2264 2308 2351 2387 2431 2475 2508 2558 2765 2802 2842 2889 2999\n",
      " 3088 3197 3311 3359 3567 3798 3924]\n",
      "y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.\n",
      "  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.  42.  21.  16.\n",
      "  22.  28.  49.  53.  24.  30.  33.  40.  60.  54.  17.  55.  38.  30.\n",
      "  41.  60.  60.  29.  33.  36.  48.  20.  41.  55.  50.  30.  44.  56.\n",
      "  33.  37.  44.  43.  36.  44.  44.  33.  50. 207.  37.  40.  47. 110.\n",
      "  89. 109. 114.  48. 208. 231. 126.]\n",
      "5000\n",
      "x= [  50   80  110  136  167  204  240  270  295  322  391  515  601  620\n",
      "  662  699  738  764  805  838  873  904  929  954  979 1021 1042 1058\n",
      " 1080 1108 1157 1210 1234 1264 1297 1337 1397 1451 1468 1523 1561 1591\n",
      " 1632 1692 1752 1781 1814 1850 1898 1918 1959 2014 2064 2094 2138 2194\n",
      " 2227 2264 2308 2351 2387 2431 2475 2508 2558 2765 2802 2842 2889 2999\n",
      " 3088 3197 3311 3359 3567 3798 3924 4040 4170 4319 4470 4600 4727 4848\n",
      " 4976]\n",
      "y= [ 50.  30.  30.  26.  31.  37.  36.  30.  25.  27.  69. 124.  86.  19.\n",
      "  42.  37.  39.  26.  41.  33.  35.  31.  25.  25.  25.  42.  21.  16.\n",
      "  22.  28.  49.  53.  24.  30.  33.  40.  60.  54.  17.  55.  38.  30.\n",
      "  41.  60.  60.  29.  33.  36.  48.  20.  41.  55.  50.  30.  44.  56.\n",
      "  33.  37.  44.  43.  36.  44.  44.  33.  50. 207.  37.  40.  47. 110.\n",
      "  89. 109. 114.  48. 208. 231. 126. 116. 130. 149. 151. 130. 127. 121.\n",
      " 128.]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(101.4, 16.73439571660716)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#可以打印或者画图的callback\n",
    "class PlottingCallback(BaseCallback):\n",
    "\n",
    "    def __init__(self, verbose=0):\n",
    "        super().__init__(verbose=0)\n",
    "\n",
    "    def _on_step(self) -> bool:\n",
    "        if self.n_calls % 1000 != 0:\n",
    "            return True\n",
    "\n",
    "        x, y = ts2xy(load_results('models'), 'timesteps')\n",
    "        print(self.n_calls)\n",
    "        print('x=', x)\n",
    "        print('y=', y)\n",
    "\n",
    "        return True\n",
    "\n",
    "\n",
    "test_callback(PlottingCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pXa8f6FsRUn8"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9da8374173dd4ad68c8b82edd87e2b4e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "(175.3, 13.849548729110273)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "#更新进度条的callback\n",
    "class ProgressBarCallback(BaseCallback):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.pbar = tqdm(total=5000)\n",
    "\n",
    "    def _on_step(self):\n",
    "        self.pbar.update(1)\n",
    "\n",
    "    def _on_training_end(self) -> None:\n",
    "        self.pbar.close()\n",
    "\n",
    "\n",
    "test_callback(ProgressBarCallback())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "174a72d6dbc747e0993d4927ddc66b03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/5000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000\n",
      "x= [ 28  55  68  77  95 138 159 196 268 279 310 341 363 405 439 462 492 546\n",
      " 600 657 722 765 851 882 930 969]\n",
      "y= [28. 27. 13.  9. 18. 43. 21. 37. 72. 11. 31. 31. 22. 42. 34. 23. 30. 54.\n",
      " 54. 57. 65. 43. 86. 31. 48. 39.]\n",
      "2000\n",
      "x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405\n",
      "  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088\n",
      " 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954]\n",
      "y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.\n",
      "  34.  23.  30.  54.  54.  57.  65.  43.  86.  31.  48.  39.  69.  50.\n",
      "  45.  44.  32.  36.  63.  54. 144. 107.  62.  82. 197.]\n",
      "3000\n",
      "x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405\n",
      "  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088\n",
      " 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954 2071 2231 2305\n",
      " 2448 2675 2739 2874]\n",
      "y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.\n",
      "  34.  23.  30.  54.  54.  57.  65.  43.  86.  31.  48.  39.  69.  50.\n",
      "  45.  44.  32.  36.  63.  54. 144. 107.  62.  82. 197. 117. 160.  74.\n",
      " 143. 227.  64. 135.]\n",
      "4000\n",
      "x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405\n",
      "  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088\n",
      " 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954 2071 2231 2305\n",
      " 2448 2675 2739 2874 3017 3180 3479 3686 3834 3922]\n",
      "y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.\n",
      "  34.  23.  30.  54.  54.  57.  65.  43.  86.  31.  48.  39.  69.  50.\n",
      "  45.  44.  32.  36.  63.  54. 144. 107.  62.  82. 197. 117. 160.  74.\n",
      " 143. 227.  64. 135. 143. 163. 299. 207. 148.  88.]\n",
      "5000\n",
      "x= [  28   55   68   77   95  138  159  196  268  279  310  341  363  405\n",
      "  439  462  492  546  600  657  722  765  851  882  930  969 1038 1088\n",
      " 1133 1177 1209 1245 1308 1362 1506 1613 1675 1757 1954 2071 2231 2305\n",
      " 2448 2675 2739 2874 3017 3180 3479 3686 3834 3922 4130 4280 4406 4507\n",
      " 4643 4846]\n",
      "y= [ 28.  27.  13.   9.  18.  43.  21.  37.  72.  11.  31.  31.  22.  42.\n",
      "  34.  23.  30.  54.  54.  57.  65.  43.  86.  31.  48.  39.  69.  50.\n",
      "  45.  44.  32.  36.  63.  54. 144. 107.  62.  82. 197. 117. 160.  74.\n",
      " 143. 227.  64. 135. 143. 163. 299. 207. 148.  88. 208. 150. 126. 101.\n",
      " 136. 203.]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt39/lib/python3.9/site-packages/stable_baselines3/common/evaluation.py:67: UserWarning: Evaluation environment is not wrapped with a ``Monitor`` wrapper. This may result in reporting modified episode lengths and rewards, if other wrappers happen to modify these. Consider wrapping environment first with ``Monitor`` wrapper.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(737.35, 255.4991731884861)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#同时使用多个callback\n",
    "test_callback([PlottingCallback(), ProgressBarCallback()])"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "4_callbacks_hyperparameter_tuning.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
